import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
from  PIL import Image
from tensorflow import keras
from tensorflow.keras import datasets,losses,layers,metrics,Sequential,optimizers

# Seed TensorFlow and NumPy with the same value so repeated runs draw the
# same random numbers from both libraries.
_SEED = 22
tf.random.set_seed(_SEED)
np.random.seed(_SEED)

def save_image(imgs, name):
    """Tile up to 100 grayscale 28x28 images into a 280x280 sheet and save it.

    Args:
        imgs: Indexable/iterable collection of 2-D uint8 arrays (each one is
            handed to ``Image.fromarray(..., mode='L')``). Only the first
            100 entries are used; if fewer are supplied the remaining cells
            stay black instead of raising ``IndexError``.
        name: Destination passed to ``Image.save``. When it is a filesystem
            path, its parent directory is created if it does not exist yet.
    """
    new_im = Image.new('L', (280, 280))

    # Top-left corner of every grid cell. The x offset is the outer loop, so
    # consecutive images fill a column top-to-bottom before moving right.
    corners = [(i, j) for i in range(0, 280, 28) for j in range(0, 280, 28)]

    # zip() stops at the shorter input, which both caps the sheet at 100
    # tiles and tolerates batches smaller than 100 images.
    for im, corner in zip(imgs, corners):
        new_im.paste(Image.fromarray(im, mode='L'), corner)

    # Image.save does not create missing directories; do it here so callers
    # can pass e.g. 'vae_images/sample.png' on a fresh checkout. File
    # objects are left untouched.
    if isinstance(name, (str, os.PathLike)):
        out_dir = os.path.dirname(name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    new_im.save(name)

# Hyperparameters. h_dim is not referenced anywhere in the visible code.
h_dim = 20
batchsz = 128
lr = 1e-3

# Fashion-MNIST train/test split; the labels (y, y_test) are only used for
# the shape printout below.
(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()

# Convert the images to float32 and divide by 255.
x_train = x.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# Training batches are drawn through a shuffle buffer of five batches;
# the test set is batched in its original order.
train_db = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(batchsz * 5)
    .batch(batchsz)
)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)

print(x_train.shape, y.shape)
print(x_test.shape, y_test.shape)

# Size of the latent vector used by the VAE and by the sampling code.
z_dim = 10

class VAE(keras.Model):
    """Fully connected variational autoencoder.

    The encoder maps an input batch to the mean and log-variance of a
    diagonal Gaussian over a ``z_dim``-dimensional latent space (``z_dim``
    is the module-level constant). The decoder maps a latent batch to 784
    un-activated outputs; the training code in this file treats them as
    logits and applies the sigmoid itself.
    """

    def __init__(self):
        super().__init__()

        # Encoder: one hidden layer feeding two parallel linear heads.
        self.fc1 = layers.Dense(128)
        self.fc2 = layers.Dense(z_dim)  # mean head
        self.fc3 = layers.Dense(z_dim)  # log-variance head

        # Decoder: one hidden layer, then 784 outputs (no activation).
        self.fc4 = layers.Dense(128)
        self.fc5 = layers.Dense(784)

    def encoder(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        h = tf.nn.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)

        return mu, log_var

    def decoder(self, z):
        """Decode latent batch ``z`` into 784 un-activated outputs per row."""
        out = tf.nn.relu(self.fc4(z))
        out = self.fc5(out)

        return out

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + std * eps`` with ``eps`` drawn from N(0, I).

        Keeping the randomness in ``eps`` lets gradients flow through
        ``mu`` and ``log_var``.
        """
        # tf.shape() is resolved at run time, so this also works when the
        # batch dimension is statically unknown (log_var.shape would then
        # contain None, which tf.random.normal rejects). The dtype follows
        # log_var so the arithmetic below never mixes float types.
        eps = tf.random.normal(tf.shape(log_var), dtype=log_var.dtype)

        # std = sqrt(exp(log_var)) = exp(log_var / 2); halving before the
        # exponential overflows much later than exp(log_var) ** 0.5.
        std = tf.exp(0.5 * log_var)

        return mu + std * eps

    def call(self, inputs, training=None):
        """Encode, sample and decode ``inputs``.

        Args:
            inputs: Batch fed to the encoder.
            training: Accepted for Keras compatibility; not used here.

        Returns:
            Tuple ``(x_hat, mu, log_var)`` where ``x_hat`` is the decoder
            output for the sampled latent vector.
        """
        # [b, 784] => [b, z_dim], [b, z_dim]
        mu, log_var = self.encoder(inputs)

        z = self.reparameterize(mu, log_var)

        x_hat = self.decoder(z)

        return x_hat, mu, log_var



model = VAE()
model.build(input_shape=(4, 784))
model.summary()

# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
optimizer = optimizers.Adam(learning_rate=lr)

# The image writer does not create missing directories, so make sure the
# output folder exists before the first epoch finishes.
os.makedirs('vae_images', exist_ok=True)


def _logits_to_images(logits):
    """Turn decoder logits into a uint8 array of 28x28 images in [0, 255]."""
    probs = tf.sigmoid(logits)
    pixels = tf.reshape(probs, [-1, 28, 28]).numpy() * 255.
    return pixels.astype(np.uint8)


for epoch in range(10):

    for step, x in enumerate(train_db):

        # Flatten each 28x28 image into a 784-vector.
        x = tf.reshape(x, [-1, 784])

        with tf.GradientTape() as tape:

            x_hat, mu, log_var = model(x)

            # Reconstruction term: Bernoulli negative log-likelihood summed
            # over the pixels of each image, then averaged over the batch.
            rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_hat)
            rec_loss = tf.reduce_mean(tf.reduce_sum(rec_loss, axis=1))

            # KL(N(mu, var) || N(0, I)), summed over the latent dimensions
            # of each sample, then averaged over the batch. Both terms now
            # share the same per-sample scale; previously the KL term was
            # averaged over every element and divided by the batch size a
            # second time, so its weight changed with the batch size.
            kl_div = -0.5 * (log_var + 1 - mu ** 2 - tf.exp(log_var))
            kl_div = tf.reduce_mean(tf.reduce_sum(kl_div, axis=1))

            loss = rec_loss + 1. * kl_div

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            # Report the KL term itself under the 'kl_div:' label (the total
            # loss used to be printed there).
            print(epoch, step, 'rec_loss:', float(rec_loss), 'kl_div:', float(kl_div))

    # evaluation: decode latent vectors drawn from the standard normal prior
    z = tf.random.normal((batchsz, z_dim))
    x_hat = _logits_to_images(model.decoder(z))
    save_image(x_hat, 'vae_images/sample_epoch_%d.png' % epoch)

    # evaluation: reconstruct the first test batch
    x = next(iter(test_db))
    x = tf.reshape(x, [-1, 784])
    x_hat_logits, _, _ = model(x)
    x_hat = _logits_to_images(x_hat_logits)
    save_image(x_hat, 'vae_images/rec_epoch_%d.png' % epoch)





